{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7bd5137ff0b2"
   },
   "source": [
    "##### Copyright 2021 The Cirq Developers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "id": "906e07f6e562"
   },
   "outputs": [],
   "source": [
    "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ytiXnAqTUBrB"
   },
   "source": [
    "# Fourier Checking Problem"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "view-in-github"
   },
   "source": [
    "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://quantumai.google/cirq/experiments/fourier_checking\"><img src=\"https://quantumai.google/site-assets/images/buttons/quantumai_logo_1x.png\" />View on QuantumAI</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/quantumlib/Cirq/blob/main/docs/experiments/fourier_checking.ipynb\"><img src=\"https://quantumai.google/site-assets/images/buttons/colab_logo_1x.png\" />Run in Google Colab</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://github.com/quantumlib/Cirq/blob/main/docs/experiments/fourier_checking.ipynb\"><img src=\"https://quantumai.google/site-assets/images/buttons/github_logo_1x.png\" />View source on GitHub</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a href=\"https://storage.googleapis.com/tensorflow_docs/Cirq/docs/experiments/fourier_checking.ipynb\"><img src=\"https://quantumai.google/site-assets/images/buttons/download_icon_1x.png\" />Download notebook</a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-7UPLcFfFzVv"
   },
   "outputs": [],
   "source": [
    "# Initial setup to install Cirq and set up dependencies for the tutorial.\n",
    "try:\n",
    "    import cirq\n",
    "except:\n",
    "    print(\"installing cirq...\")\n",
    "    !pip install --quiet cirq\n",
    "    print(\"installed cirq.\")\n",
    "    import cirq\n",
    "\n",
    "from typing import Sequence\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "\n",
    "# Sets a seed for deterministic results.  Uncomment for random results each run.\n",
    "np.random.seed(2021)\n",
    "np.set_printoptions(precision=3, suppress=True, linewidth=200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2k_4KaHyDMSh"
   },
   "source": [
    "# Introduction\n",
    "\n",
    "\n",
    "In past decades, there is ample evidence suggesting that quantum computers can be exponentially more powerful in solving certain computational tasks than their classical couterparts. The *black-box* or *query* model, such as Grover’s search, Deutsch-Jozsa’s algorithm, etc., offers a concrete setting to show such exponential speedups. Normally, one provides \"black-box access\" to a function $f$, meaning that the quantum algorithm can apply a unitary\n",
    "transformation that maps basis states of the form $|x, y \\rangle$ to to basis states of the form $|x, y \\oplus f(x)\\rangle $ or $|x\\rangle$ to $(-1)^{f(x)} |x\\rangle$ if $f$ is Boolean. Then, a natural question is asked:\n",
    "\n",
    "> What is the maximal possible separation between quantum and classical query complexities?\n",
    "\n",
    "For example, could there be a function of $N$ bits with a quantum query\n",
    "complexity of 1, but a classical randomized query complexity of $\\Omega(\\sqrt{N})$ or $\\Omega(N)$?\n",
    "Specifically, Buhrman et al. [[1]](https://www.sciencedirect.com/science/article/pii/S030439750100144X) from 2002 asked whether there is any\n",
    "property of $N-$bit strings that exhibits a **“maximal”** separation: that is, one that requires $\\Omega(N)$\n",
    "queries to test classically, but only O (1) quantumly.\n",
    "\n",
    "**Fourier Checking** is a problem that provides a separation between quantum  and classical computers -- $O(1)$ VS $\\tilde{\\Omega}(\\sqrt{N})$, which can be proved as optimal. Currently, it only has theoretical importance - but, as it falls into the category of small quantum algorithms, it can be used to demonstrate query complexity and oracle synthesis in Cirq.\n",
    "\n",
    "Goal of this notebook is to introduce:\n",
    "\n",
    "1. What is Forrelation and the Fourier Checking problem and why we are interested in it?\n",
    "2. What is bounded-error quantum polynomial time (BQP) and why does the Fourier Checking problem belong to it?\n",
    "3. How to implement the Fourier Checking algorithm and an oracle function in Cirq?\n",
    "\n",
    "We won't include the formal proofs and argument. However, we do give sketches of the derivation for intuition and encourage the reader to check the corresponding lemmas and theorems in the original paper.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hq6-obwCiEXX"
   },
   "source": [
    "# Preliminary\n",
    "\n",
    "Before we present the Fourier Checking problem, three preliminary concepts, 1) BPP and BQP 2) Fourier Transform over $Z_2^n$ and 3) Forrelation, are introduced first.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HNCLXihE01A7"
   },
   "source": [
    "## Bounded-error Probabilistic Polynomial time (BPP)\n",
    "\n",
    "In computational complexity theory, [bounded-error probabilistic polynomial time (BPP)](https://en.wikipedia.org/wiki/BPP_(complexity)) is the class of decision problems solvable by a [probabilistic Turing machine](https://en.wikipedia.org/wiki/Probabilistic_Turing_machine) in polynomial time with an error probability bounded away from 1/3 for all instances:\n",
    "\n",
    "| &nbsp; &nbsp;&nbsp;&nbsp;&nbsp; Anwser Provided  <br />  <br /> Correct Anwser| Yes <br /> <br /> &nbsp; | No <br /> <br /> &nbsp;|\n",
    "|:---------------|------|-------|\n",
    "| Yes           | $\\geq$ 2/3 | $\\leq$ 1/3 |\n",
    "| No            | $\\leq$ 1/3 | $\\geq$ 2/3 |\n",
    "\n",
    "The choice of 1/3 in the definition is arbitrary. It can be any constant between 0 and ​1/2 (exclusive) and the set BPP will be unchanged."
   ]
  },
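  {
   "cell_type": "markdown",
   "metadata": {
    "id": "amplification-sketch-md"
   },
   "source": [
    "The constant 1/3 is arbitrary because of *amplification*: run the algorithm several times independently and output the majority answer. The minimal sketch below (plain Python; `majority_error` is a hypothetical helper, not part of Cirq) assumes independent runs that each err with probability 1/3 and computes the exact error of a majority vote over an odd number of runs. The error shrinks exponentially with the number of runs, which is why any constant error bound strictly below 1/2 leads to the same class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "amplification-sketch"
   },
   "outputs": [],
   "source": [
    "from math import comb\n",
    "\n",
    "\n",
    "def majority_error(p_err: float, runs: int) -> float:\n",
    "    # Hypothetical helper: probability that a strict majority of `runs`\n",
    "    # independent runs is wrong, when each run errs with probability p_err.\n",
    "    assert runs % 2 == 1, 'use an odd number of runs to avoid ties'\n",
    "    return sum(\n",
    "        comb(runs, k) * p_err**k * (1 - p_err) ** (runs - k)\n",
    "        for k in range(runs // 2 + 1, runs + 1)\n",
    "    )\n",
    "\n",
    "\n",
    "for runs in [1, 3, 11, 51]:\n",
    "    print(f'runs={runs:2d}  majority error={majority_error(1 / 3, runs):.6f}')"
   ]
  },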
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Mj1IjP3TWVn0"
   },
   "source": [
    "To understand this better, let's consider a simple example.\n",
    "\n",
    "**Problem**: Supposing we have a boolean function $f: \\{0, 1\\}^n \\rightarrow\\{-1, 1\\}$. The function is drawn either from distribution $\\mathcal{B}$ or $\\mathcal{C}$. Under the distribution $\\mathcal{B}$, the function $f$ is balanced, i.e., $\\sum_{x} f(x) = 0$. Under the distribution $\\mathcal{C}$, the function $f$ is constant, i.e., function is always equal to $+1$ or $-1$. The problem is to accept the function if it is drawn from $\\mathcal{B}$ or reject it otherwise.\n",
    "\n",
    "For example, consider the case that $n=2$. There are only four possible functions listed in following table\n",
    "\n",
    "|x   | f_0(x)| f_1(x)| f_2(x)| f_3(x)|\n",
    "|----|----:|----:|----:|----:|\n",
    "| 0  | +1 | -1 | +1 | -1 |\n",
    "| 1  | -1 | +1 | +1 | -1 |\n",
    "\n",
    "Clearly, $f_0$ and $f_1$ belongs to $\\mathcal{B}$ and $f_2$ and $f_3$ belongs to $\\mathcal{C}$.\n",
    "\n",
    "**Deterministic Algorithm**:  We evaluate the outputs of function of $2^{n-1}+1$ different inputs. If the results contain both $+1$ and $-1$ value, the function must be drawn from distribution $\\mathcal{B}$. Otherwise, it must come from $\\mathcal{C}$.\n",
    "\n",
    "Remembering that the function is guaranteed to be either balanced or constant, not somewhere in between. So above algorithm is guaranteed to be always correct. However, the query complexity is $O(N)$, denoting $N=2^n$. Here we are interested in the query complexity instead of computation complexity. Namely, how many times we have to evaluate $f$ -- which you can imagine being a very costly function. If the bounded-error probability is acceptable for solving the problem, we can achieve better algorithm in terms of query complexity.\n",
    "\n",
    "**Randomized Algorithm**: Randomly select $K$ different inputs and evaluate the corresponding outputs. If both $+1$ and $-1$ are observed, we accept it. Otherwise, it reject it.\n",
    "\n",
    "The complexity of algorithm depends on the choice of $K$. For this problem, selecting fixed number is sufficient, i.e. $O(1)$ query complexity. It is worth to remark that $O(1)$ means no matter the fixed number is independent of the input size of function $N$. Consider the $K=2$ case, it is not hard to establish the following confusion matrix regardless of $N$:\n",
    "\n",
    "| Randomized Algorithm (K=2):  | Accept  | Reject |\n",
    "|---------------|---------|-------|\n",
    "| Drawn from $\\mathcal{B}$ | 1/2 | 1/2  |\n",
    "| Drawn from $\\mathcal{C}$ | 0 | 1 |\n",
    "\n",
    "This is not sufficient to solve it. However, if we select $K>2$ entries, the  probability of correctness will boost. Let's use the code to exam it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DzkklbuNWWEi"
   },
   "outputs": [],
   "source": [
    "def gen_balanced_function(N: int) -> np.ndarray:\n",
    "    \"\"\"Generates a balanced function for N bits.\n",
    "\n",
    "    Creates a function 𝑓:{0,1}^N → {−1,1}\n",
    "    where f(x)=-1 for half of the inputs and f(x)=1 for the other half.\n",
    "\n",
    "    Returns:\n",
    "       the function as represented by a 1-d numpy array of size N\n",
    "    \"\"\"\n",
    "    half_size = N // 2\n",
    "    f = np.ones(N)\n",
    "    flip_loc = np.random.permutation(N)[:half_size]\n",
    "    f[flip_loc] = -1\n",
    "    return f\n",
    "\n",
    "\n",
    "def gen_constant_function(N: int) -> np.ndarray:\n",
    "    \"\"\"Generates a constant  function for N bits.\n",
    "\n",
    "    Creates a function 𝑓:{0,1}^𝑛 → {−1,1}\n",
    "    where f(x)=c for all inputs.\n",
    "\n",
    "    c is randomly chosen as either -1 or 1, but, once chosen,\n",
    "    is constant for all values of x.\n",
    "\n",
    "    Returns:\n",
    "        the function as represented by a 1-d numpy array of size N\n",
    "    \"\"\"\n",
    "\n",
    "    flip = np.random.random() > 0.5\n",
    "    f = np.ones(N) if flip else -1 * np.ones(N)\n",
    "    return f\n",
    "\n",
    "\n",
    "def choose_random_function() -> tuple[str, np.ndarray]:\n",
    "    \"\"\"Randomly choose a function from constant or balanced distributions.\n",
    "\n",
    "    Returns:\n",
    "        a tuple of the distribution (\"B\" or \"C\") and the function as an array.\n",
    "    \"\"\"\n",
    "    if np.random.rand() > 0.5:\n",
    "        f = gen_balanced_function(N)\n",
    "        dist = \"B\"\n",
    "    else:\n",
    "        f = gen_constant_function(N)\n",
    "        dist = \"C\"\n",
    "    return dist, f\n",
    "\n",
    "\n",
    "def randomized_alg(f: np.ndarray, sample_size: int) -> str:\n",
    "    \"\"\"Samples the function f from `sample_size` different inputs.\n",
    "\n",
    "    Queries the function f a number of times equal to sample_size.\n",
    "    If all the inputs are the same, then guess that the function\n",
    "    is constant.  If any inputs are different, then guess the function\n",
    "    is balanced.\n",
    "\n",
    "    Args:\n",
    "        f: the function to sample\n",
    "        sample_size: number of times to sample the function f\n",
    "\n",
    "    Returns:\n",
    "        a string representing the type of function, either\n",
    "            \"balanced\" or \"constant\"\n",
    "    \"\"\"\n",
    "    N = len(f)\n",
    "    sample_index = np.random.choice(N, size=sample_size)\n",
    "    if len(set(f[sample_index])) == 2:\n",
    "        return \"balanced\"\n",
    "    return \"constant\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Zo7zeqllWWNf"
   },
   "outputs": [],
   "source": [
    "N = 128  # size of the problem, n=7, N=2^7=128\n",
    "samples_size_per_function = 3\n",
    "number_of_functions_to_try = 1000\n",
    "\n",
    "\n",
    "res = pd.DataFrame()\n",
    "for _ in range(number_of_functions_to_try):\n",
    "    dist, f = choose_random_function()\n",
    "    decision = randomized_alg(f, samples_size_per_function)\n",
    "    res = pd.concat(\n",
    "        [res, pd.DataFrame({\"Distribution\": [dist], \"Decision\": [decision], \"Count\": [1]})],\n",
    "        ignore_index=True,\n",
    "    )\n",
    "confusion = res.pivot_table(index=\"Distribution\", columns=\"Decision\", values=\"Count\", aggfunc=\"sum\")\n",
    "# Translate the counts into percentage\n",
    "confusion.div(confusion.sum(axis=1), axis=0).apply(lambda x: round(x, 4) * 100)"
   ]
  },
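  {
   "cell_type": "markdown",
   "metadata": {
    "id": "exact-confusion-sketch-md"
   },
   "source": [
    "For this particular algorithm, the confusion matrix can also be computed exactly. The sketch below (`accept_prob_balanced` is a hypothetical helper) assumes, as `randomized_alg` does through the default `replace=True` of `np.random.choice`, that inputs are sampled uniformly with replacement. Each sample of a balanced function is then $+1$ or $-1$ with probability 1/2, independently, so all $K$ samples agree with probability $2 \\cdot 2^{-K}$; a constant function is never misclassified. The result does not depend on $N$ and can be compared with the percentages estimated above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "exact-confusion-sketch"
   },
   "outputs": [],
   "source": [
    "def accept_prob_balanced(sample_size: int) -> float:\n",
    "    # Hypothetical helper: P(answer is 'balanced' | f is balanced) for one run.\n",
    "    # The run wrongly answers 'constant' only if all samples agree: 2 * (1/2)**K.\n",
    "    return 1 - 2 * 0.5**sample_size\n",
    "\n",
    "\n",
    "# A constant function always yields identical samples, so P(balanced | C) = 0.\n",
    "for K in [1, 2, 3, 5, 10]:\n",
    "    print(f'K={K:2d}  P(balanced | B) = {accept_prob_balanced(K):.4f}  P(balanced | C) = 0')"
   ]
  },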
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dU4e1_FrbME1"
   },
   "source": [
    "Recall that the B in BPP means *bounded-error*. Actually, you can achieve arbitary small error rate under the same query complexity order. Say, you set an acceptable error rate $\\epsilon$. Then the key is that we can run the algorithm multiple times. Repeat it as many times as you want until the error rate is lower than $\\epsilon$. It is crucial to note that the error rate for this particular problem does not depend on the size of the input but only on the size of the sample and the reptitation. For this reason, in order to get to a bounded probability error, it is sufficient to just adjust the sample size and/or repetitions to a given constant - which means that the \"query complexity\" of the algorithm will stay $O(1)$. For example, let's run previous algorithms 3 times and make the final decision based on the majority of the decision of each term. You should verify that the probability of error indeed decreased and independent of $N$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uY_VUjr0bLS_"
   },
   "outputs": [],
   "source": [
    "N = 128\n",
    "samples_size_per_function = 3\n",
    "repetitions_of_randomized_alg = 3\n",
    "number_of_functions_to_try = 1000\n",
    "\n",
    "res = pd.DataFrame()\n",
    "for _ in range(number_of_functions_to_try):\n",
    "    dist, f = choose_random_function()\n",
    "    constant_minus_blanaced_count = 0\n",
    "    for _ in range(repetitions_of_randomized_alg):\n",
    "        decision = randomized_alg(f, samples_size_per_function)\n",
    "        constant_minus_blanaced_count += 1 if decision == \"constant\" else -1\n",
    "    final_decision = \"constant\" if constant_minus_blanaced_count > 0 else \"balanced\"\n",
    "    res = pd.concat(\n",
    "        [res, pd.DataFrame({\"Distribution\": [dist], \"Decision\": [final_decision], \"Count\": [1]})],\n",
    "        ignore_index=True,\n",
    "    )\n",
    "confusion = res.pivot_table(index=\"Distribution\", columns=\"Decision\", values=\"Count\", aggfunc=\"sum\")\n",
    "# Translate the counts into percentage\n",
    "confusion.div(confusion.sum(axis=1), axis=0).apply(lambda x: round(x, 4) * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Tda2vy2pb1aU"
   },
   "source": [
    "If you try the algorithm with more repetitions, you will observe the error rate decreases rapidly. You can also try different values of $K, N, $ and/or repetitions to see how the confusion matrix changes according.\n",
    "\n",
    "After you understand the concept of BPP, it is easy to understand [bounded-error quantum polynomial time (BQP)](https://en.wikipedia.org/wiki/BQP) now. BQP is the class of decision problems solvable by a quantum computer in polynomial time, with an error probability of at most 1/3 for all instances. It is the quantum analogue to the complexity class BPP. Actually, with a quantum computer, the previous problem can be solved using the[Deutsch–Jozsa algorithm](https://en.wikipedia.org/wiki/Deutsch%E2%80%93Jozsa_algorithm).  This algorithm utilizes a single query and is guaranteed to be correct always. The Fourier Checking problem that will be introduced later belongs to BQP as well."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "spyqzGyf5kGW"
   },
   "source": [
    "## Fourier Transform over $\\mathbb{Z}^n_2$\n",
    "In this colab, we are interested in the boolean function of the form $f : \\{0, 1\\}^n \\rightarrow\\{-1, 1\\}$. In this case, the Fourier transform of $f$ over $\\mathbb{Z}^n_2$ is defined as\n",
    "$$\n",
    "  \\hat{f}(y) := \\frac{1}{\\sqrt{N}} \\sum_{x\\in\\{0,1\\}^n} (-1)^{x \\cdot y} f\n",
    "  (x).\n",
    "$$\n",
    "where $x \\cdot y$ means the bit-wise inner product between $x$ and $y$. Note this is not the standard discrete Fourier transform definition over $\\mathbb{Z}_N$. According to the Parseval's identity, we have\n",
    "\n",
    "$$\n",
    "  \\sum_{x\\in\\{0,1\\}^n} f(x)^2 = \\sum_{y\\in\\{0,1\\}^n} \\hat{f}(y)^2  = N .\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ucEcJmSQ26vg"
   },
   "outputs": [],
   "source": [
    "def bitwise_dot(x: int, y: int) -> int:\n",
    "    \"\"\"Compute the dot product of two integers bitwise.\"\"\"\n",
    "    i = x & y\n",
    "\n",
    "    n = bin(i).count(\"1\")\n",
    "    return int(n % 2)\n",
    "\n",
    "\n",
    "def fourier_transform_over_z2(v: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"Fourier transform function over z_2^n group.\n",
    "\n",
    "    Args:\n",
    "        v: an array with 2**n elements.\n",
    "\n",
    "    Returns:\n",
    "        vs: a numpy array with same length as input.\n",
    "    \"\"\"\n",
    "    N = len(v)\n",
    "    assert bin(N).count(\"1\") == 1, \"v must be a 2**n long vector\"\n",
    "    v_hat = np.array([0.0] * N)\n",
    "    for y in range(N):\n",
    "        for x in range(N):\n",
    "            v_hat[y] += ((-1) ** bitwise_dot(x, y)) * v[x]\n",
    "    return v_hat / np.sqrt(N)"
   ]
  },
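  {
   "cell_type": "markdown",
   "metadata": {
    "id": "walsh-hadamard-sketch-md"
   },
   "source": [
    "The double loop in `fourier_transform_over_z2` costs $O(N^2)$ operations. The same transform can be written as a matrix-vector product $\\hat{f} = W f / \\sqrt{N}$ with $W_{yx} = (-1)^{x \\cdot y}$, where $W$ is the $n$-fold Kronecker power of the $2 \\times 2$ matrix `[[1, 1], [1, -1]]` (the Sylvester construction of a Hadamard matrix). The sketch below (`walsh_hadamard_matrix` is a hypothetical helper) builds $W$ this way, checks its entries against the bitwise-parity definition, and checks both Parseval's identity and the fact that applying the transform twice returns the original vector. It uses its own `np.random.default_rng` generator so that the notebook's global random state is left untouched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "walsh-hadamard-sketch"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def walsh_hadamard_matrix(n: int) -> np.ndarray:\n",
    "    # Hypothetical helper: W[y, x] = (-1)**popcount(x & y), built with the\n",
    "    # Sylvester recursion W_{k+1} = [[W_k, W_k], [W_k, -W_k]].\n",
    "    w = np.array([[1]])\n",
    "    for _ in range(n):\n",
    "        w = np.kron(np.array([[1, 1], [1, -1]]), w)\n",
    "    return w\n",
    "\n",
    "\n",
    "n_demo = 3\n",
    "N_demo = 2**n_demo\n",
    "W_demo = walsh_hadamard_matrix(n_demo)\n",
    "\n",
    "# Entries agree with the bitwise-parity definition of (-1)**(x . y).\n",
    "parity_ok = all(\n",
    "    W_demo[y, x] == (-1) ** (bin(x & y).count('1') % 2)\n",
    "    for x in range(N_demo)\n",
    "    for y in range(N_demo)\n",
    ")\n",
    "\n",
    "f_demo = np.where(np.random.default_rng(0).random(N_demo) < 0.5, -1, 1)\n",
    "f_hat_demo = W_demo @ f_demo / np.sqrt(N_demo)\n",
    "\n",
    "print('entries match parity definition:', parity_ok)\n",
    "print('Parseval:', np.sum(f_demo**2), np.isclose(np.sum(f_hat_demo**2), N_demo))\n",
    "print('self-inverse:', np.allclose(W_demo @ f_hat_demo / np.sqrt(N_demo), f_demo))"
   ]
  },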
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fqXHBoEm7ebP"
   },
   "source": [
    "Let's have some examples in $\\mathbb{Z}^2_2$. You should verify that both functions have same energy 4(as defined by Parseval's identity above).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AueS6I2f4hAx"
   },
   "outputs": [],
   "source": [
    "f = np.array([1, -1, 1, -1])\n",
    "f_hat = fourier_transform_over_z2(f)\n",
    "print(f\"f: {list(f)} f_hat: {list(f_hat)}\")\n",
    "\n",
    "f = np.array([1, 1, 1, -1])\n",
    "f_hat = fourier_transform_over_z2(f)\n",
    "print(f\"f: {list(f)} f_hat: {list(f_hat)}\")\n",
    "\n",
    "f = np.array([1, -1, -1, 1])\n",
    "f_hat = fourier_transform_over_z2(f)\n",
    "print(f\"f: {list(f)} f_hat: {list(f_hat)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "02GAaXsAsYOW"
   },
   "source": [
    "## Forrelation\n",
    "\n",
    "The concept of *forrelation* is the key concept in the Fourier checking problem, which is a combination of the words \"Fourier Transform\" and \"Correlation\".\n",
    "Recall the classical correlation between two vectors $u$ and $v$ is defined as\n",
    "$$\n",
    "  {\\rm correlation}(u,v) = \\frac{ \\langle u, v \\rangle }{\\|u\\| \\|v\\|}.\n",
    "$$\n",
    "Then, the forrelation between two vectors $u$ and $v$ is just the correlation between $u$ and the Fourier trasformed of $v$ -- denoted as $\\widehat{v}$:\n",
    "\\begin{align}\n",
    "  {\\rm forrelation}(u, v) =& \\frac{\\langle u, \\widehat{v} \\rangle }{\\|u\\| \\|\\widehat{v}\\|},\\\\\n",
    "  =& \\frac{ \\langle u, \\widehat{v} \\rangle }{\\|u\\| \\|v\\|}.\n",
    "\\end{align}\n",
    "where the second equality is due to the Parseval's identity.\n",
    "Since in this tutorial we are interesed in Boolean function, we replace the arbitrary vector $u$ and $v$ by the output of Boolean function $f$ and $g$. Now we can further simplify the above definition:\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "  {\\rm forrelation}(f, g) =& \\frac{\\langle f, \\widehat{g} \\rangle }{\\|f\\| \\|g\\|}\\\\\n",
    "  =& \\frac{1}{N} \\langle f, \\widehat{g}\\rangle  \\\\\n",
    "  =& \\frac{1}{N} \\sum_{x \\in \\{0,1\\}^n}f(x)\\widehat{g}(x)\\\\\n",
    "  =& \\frac{1}{N^{3/2}} \\sum_{x, y \\in \\{0,1\\}^n}f(x)(-1)^{x \\cdot y}g(y)\n",
    "\\end{align}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9r_wFSV6BT3s"
   },
   "outputs": [],
   "source": [
    "def get_correlation(f: np.ndarray, g: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"Returns the classical correlation between two 1-d numpy arrays.\"\"\"\n",
    "    return f.dot(g) / np.linalg.norm(f) / np.linalg.norm(g)\n",
    "\n",
    "\n",
    "def get_forrelation(f: np.ndarray, g: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"Returns the forrelation over Z^2 between two 1-d numpy arrays.\"\"\"\n",
    "    g_hat = fourier_transform_over_z2(g)\n",
    "    return f.dot(g_hat) / np.linalg.norm(f) / np.linalg.norm(g)"
   ]
  },
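  {
   "cell_type": "markdown",
   "metadata": {
    "id": "forrelation-double-sum-md"
   },
   "source": [
    "As a check on the last line of the derivation above, the sketch below evaluates the double sum $\\frac{1}{N^{3/2}} \\sum_{x, y} f(x)(-1)^{x \\cdot y}g(y)$ directly, without computing $\\widehat{g}$ first. The helper `forrelation_double_sum` is hypothetical; for $\\pm 1$-valued inputs it should agree with `get_forrelation`, since then $\\|f\\| \\|g\\| = N$. By the Cauchy-Schwarz inequality and Parseval's identity, the result always lies in $[-1, 1]$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "forrelation-double-sum"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def forrelation_double_sum(f: np.ndarray, g: np.ndarray) -> float:\n",
    "    # Hypothetical helper: (1 / N**1.5) * sum_{x, y} f(x) (-1)**(x . y) g(y).\n",
    "    # O(N**2) loops, intended for small examples only.\n",
    "    N = len(f)\n",
    "    total = 0.0\n",
    "    for x in range(N):\n",
    "        for y in range(N):\n",
    "            parity = bin(x & y).count('1') % 2\n",
    "            total += f[x] * (-1) ** parity * g[y]\n",
    "    return total / N**1.5\n",
    "\n",
    "\n",
    "f_ex = np.array([1, -1, 1, -1])\n",
    "print(forrelation_double_sum(f_ex, f_ex))  # -0.5\n",
    "\n",
    "rng_ex = np.random.default_rng(1)\n",
    "for _ in range(3):\n",
    "    a = rng_ex.choice([-1, 1], size=8)\n",
    "    b = rng_ex.choice([-1, 1], size=8)\n",
    "    print(round(forrelation_double_sum(a, b), 4))  # always within [-1, 1]"
   ]
  },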
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "evFK1aCTBm1E"
   },
   "outputs": [],
   "source": [
    "# let's see some examples to gain some insights of forrelation\n",
    "f = np.array([1, -1, 1, -1])\n",
    "g = np.array([1, -1, 1, -1])\n",
    "print(f\"Correlation: {get_correlation(f,g)}  Forrelation: {get_forrelation(f,g)}\")\n",
    "\n",
    "f = np.array([1, 1, 1, -1])\n",
    "g = np.array([-1, -1, -1, 1])\n",
    "print(f\"Correlation: {get_correlation(f,g)}  Forrelation: {get_forrelation(f,g)}\")\n",
    "\n",
    "f = np.array([1, -1, -1, 1])\n",
    "g = np.array([1, 1, 1, 1])\n",
    "print(f\"Correlation: {get_correlation(f,g)}  Forrelation: {get_forrelation(f,g)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ozbeF_nZiEib"
   },
   "source": [
    "# Fourier Checking Problem\n",
    "\n",
    "Now we are prepared to discuss the *Fourier Checking* problem. Here we are given oracle access to two Boolean functions $f,g : \\{0, 1\\}^n \\rightarrow\\{-1, 1\\}$. We are promised that one of the following two cases is true:\n",
    "\n",
    "- $\\langle f, g \\rangle$ was drawn from the uniform distribution $\\mathcal{U}$.\n",
    "- $\\langle f, g \\rangle$ was drawn from the forrelated distribution $\\mathcal{F}$ (Will be explained in more details later).\n",
    "\n",
    "The problem is a decision problem that accepts the $\\langle f, g \\rangle$ if it was drawn from $\\mathcal{F}$ and rejects $\\langle f, g \\rangle$ if it was drawn from $\\mathcal{U}$.\n",
    "\n",
    "*Note: Since $\\mathcal{F}$ and $\\mathcal{U}$ overlap slightly, we can only hope to succeed with overwhelming probability over the choice of  $\\langle f, g \\rangle$ , not for every  $\\langle f, g \\rangle$  pair.*\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hQMu28BGwZYS"
   },
   "source": [
    "## Generate functions $f$ and $g$ from distribution $\\mathcal U$ and $\\mathcal F$\n",
    "The forrelated distribution $\\mathcal{F}$ is generated as followings.\n",
    "First\n",
    "choose a random real vector $v=(v_x)_{x\\in\\{0, 1\\}^n} \\in \\mathbb{R}^N$, by drawing each entry independently from Gaussian distribution with zero-mean and variance 1. Then set $f(x) := {\\rm sgn}(v_x)$ and $g(x) := {\\rm sgn}(\\widehat{v}_x)$, where $\\widehat{v}_x$ the Fourier Transform of $v$ is\n",
    "$$\n",
    "  \\widehat{v}_y := \\frac{1}{\\sqrt{N}} \\sum_{x\\in\\{0,1\\}^n} (-1)^{x\\cdot y}v_x,\n",
    "$$\n",
    "and \n",
    "$$\n",
    "  {\\rm sgn}(\\alpha) := \\left\\{\n",
    "  \\begin{aligned}\n",
    "    1 \\;\\;\\; &{\\rm if}\\; \\alpha \\geq 0 \\\\\n",
    "    -1 \\;\\;\\;&{\\rm if}\\; \\alpha < 0\n",
    "  \\end{aligned}\\right.\n",
    "$$\n",
    "Notice, $f$ and $g$ *individually* are still uniformly random, but they are no longer independent. Now $f$ is forrelated with $g$.\n",
    "For simplicity, we only consider the *PROMISE FOURIER CHECKING* problem. Under this situation, we are promised that the quatity:\n",
    "\n",
    "\\begin{align}\n",
    "  p(f,g) := \\left[{\\rm forrelation}(f, g)\\right]^2 = \\frac{1}{N^3} \\left(\\sum_{x, y \\in \\{0,1\\}^n}f(x)(-1)^{x \\cdot y}g(y)\\right)^2\n",
    "\\end{align}\n",
    "\n",
    "is either at least 0.05 or at most 0.01."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ow6khbkKebag"
   },
   "outputs": [],
   "source": [
    "def draw_two_distribution_from_f_set(N: int) -> tuple[np.ndarray, np.ndarray, float, float]:\n",
    "    \"\"\"Samples two distributions from the 'F' set above.\n",
    "\n",
    "    Uses a while loop to guarantee a forrelated pair \"as promised\".\n",
    "\n",
    "    Returns:\n",
    "        A tuple that contains the two distributions, and the correlation/forrelation.\n",
    "    \"\"\"\n",
    "    sgn = lambda x: 1 if x >= 0 else -1\n",
    "    forrelation = 0.2\n",
    "    while (abs(forrelation) ** 2 < 0.05) and (abs(forrelation) ** 2 > 0.01):\n",
    "        vs = np.array([np.random.normal() for _ in range(N)])\n",
    "        vs_hat = fourier_transform_over_z2(vs)\n",
    "        fs = np.array([sgn(v) for v in vs])\n",
    "        gs = np.array([sgn(v_hat) for v_hat in vs_hat])\n",
    "        forrelation = get_forrelation(fs, gs)\n",
    "        correlation = get_correlation(fs, gs)\n",
    "    return fs, gs, forrelation, correlation\n",
    "\n",
    "\n",
    "def draw_two_distribution_from_u_set(N: int) -> tuple[np.ndarray, np.ndarray, float, float]:\n",
    "    \"\"\"Samples two distributions from the 'U' set above.\n",
    "\n",
    "    Uses a while loop to guarantee a forrelated pair \"as promised\".\n",
    "\n",
    "    Returns:\n",
    "        A tuple that contains the two distributions, and the correlation/forrelation.\n",
    "    \"\"\"\n",
    "    sgn = lambda x: 1 if x >= 0 else -1\n",
    "    forrelation = 0.2\n",
    "    while (abs(forrelation) ** 2 < 0.05) and (abs(forrelation) ** 2 > 0.01):\n",
    "        vs = np.array([np.random.normal() for _ in range(N)])\n",
    "        fs = np.array([sgn(v) for v in vs])\n",
    "        us = np.array([np.random.normal() for _ in range(N)])\n",
    "        gs = np.array([sgn(u) for u in us])\n",
    "        forrelation = get_forrelation(fs, gs)\n",
    "        correlation = get_correlation(fs, gs)\n",
    "    return fs, gs, forrelation, correlation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bVFeIs6-RrfP"
   },
   "outputs": [],
   "source": [
    "n = 6\n",
    "N = 2**n\n",
    "\n",
    "fs, gs, forrelation, correlation = draw_two_distribution_from_f_set(N)\n",
    "print('Correlation and forrelation from F set')\n",
    "print(f\"fs: {list(fs)}\")\n",
    "print(f\"gs: {list(gs)}\")\n",
    "print(f'Correlation: {correlation} Forrelation: {forrelation}')\n",
    "plt.figure(figsize=(15, 5))\n",
    "plt.stem(fs)\n",
    "plt.stem(gs, linefmt='--r', markerfmt='ro')\n",
    "plt.title(f\"Two distributions from F set\")\n",
    "\n",
    "print('')\n",
    "print('Correlation and forrelation from U set')\n",
    "fs, gs, forrelation, correlation = draw_two_distribution_from_u_set(N)\n",
    "print(f\"fs: {list(fs)}\")\n",
    "print(f\"gs: {list(gs)}\")\n",
    "print(f'Correlation: {correlation} Forrelation: {forrelation}')\n",
    "\n",
    "plt.figure(figsize=(15, 5))\n",
    "plt.stem(fs)\n",
    "plt.stem(gs, linefmt='--r', markerfmt='ro')\n",
    "_ = plt.title(f\"Two distributions from U set\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TZwIP2_337WH"
   },
   "source": [
    "Typically, $\\mathcal{U}$ and $\\mathcal{F}$ is not obviously different from each other even after we plot the whole functions information together. However, it is not hard to show that Fourier Checking is in BQP: basically, one can prepare a uniform superposition over all $x\\in\\{0,1\\}^n$, then query $f$, apply a quantum Fourier transform, query $g$, and\n",
    "check whether one has recovered something close to the uniform superposition. On the other hand, being forrelated seems like an extremely “global” property of $f$ and $g$: one that would not be apparent from querying any small number of $f$ and $g$ values, regardless of the outcomes of those queries."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "obqAoSrNRFJ9"
   },
   "source": [
    "# Quantum Algorithm for Fourier Checking\n",
    "\n",
    "Now, we present the quantum algorithm for Fourier Checking, which is quite simple actually.\n",
    "It can achieve the constant error probability with $O(1)$ query complexity.\n",
    "\n",
    "**Algorithm Description:**\n",
    "\n",
    "First, we prepare a uniform superposition over all $x \\in \\{0, 1\\}^n$. Then query $f$ in superposition, to create the state\n",
    "$$\n",
    "  \\frac{1}{\\sqrt{N}} \\sum_{x \\in \\{0, 1\\}^n} f(x) |x\\rangle\n",
    "$$\n",
    "Applying Hadmard gates to all $n$ qubits, to create the state\n",
    "$$\n",
    "  \\frac{1}{N} \\sum_{x,y  \\in \\{0, 1\\}^n} f(x) (-1)^{x\\cdot y} |y\\rangle\n",
    "$$\n",
    "Then query $g$ in superposition, to create the state\n",
    "$$\n",
    "  \\frac{1}{N} \\sum_{x,y  \\in \\{0, 1\\}^n} f(x) (-1)^{x\\cdot y}g(y) |y\\rangle\n",
    "$$\n",
    "Then apply Hadmard gates to all $n$ qubits again, to create the state\n",
    "$$\n",
    "  \\frac{1}{N^{3/2}} \\sum_{x,y  \\in \\{0, 1\\}^n} f(x) (-1)^{x\\cdot y} g(y) (-1)^{y \\cdot z}|z\\rangle\n",
    "$$\n",
    "Finally, measure in the computational basis, and \"accept\" if and only if the outcome $|0\\rangle^{\\otimes n}$is observed.\n",
    "\n",
    "If needed, repeat the whole algorithm $O(1)$ times to boost the success probability."
   ]
  },
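  {
   "cell_type": "markdown",
   "metadata": {
    "id": "statevector-sketch-md"
   },
   "source": [
    "The four steps above can be checked with a plain NumPy state-vector sketch before turning to Cirq. Here `hadamard_all` and `fourier_checking_accept_prob` are hypothetical helpers, not Cirq APIs: a query to $f$ multiplies each amplitude by $f(x) = \\pm 1$, and a layer of Hadamard gates is the matrix with entries $(-1)^{x \\cdot y}/\\sqrt{N}$. The returned probability of measuring $|0\\rangle^{\\otimes n}$ should equal the squared forrelation of $f$ and $g$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "statevector-sketch"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def hadamard_all(state: np.ndarray) -> np.ndarray:\n",
    "    # Hypothetical helper: applies H to every qubit of a 2**n state vector.\n",
    "    N = len(state)\n",
    "    parity = np.array([[bin(x & y).count('1') % 2 for x in range(N)] for y in range(N)])\n",
    "    return ((-1.0) ** parity) @ state / np.sqrt(N)\n",
    "\n",
    "\n",
    "def fourier_checking_accept_prob(f: np.ndarray, g: np.ndarray) -> float:\n",
    "    # Hypothetical helper: simulates the four steps and returns P(|0...0>).\n",
    "    state = np.zeros(len(f))\n",
    "    state[0] = 1.0  # start in |0...0>\n",
    "    state = hadamard_all(state)  # uniform superposition\n",
    "    state = f * state  # query f as a phase oracle\n",
    "    state = hadamard_all(state)\n",
    "    state = g * state  # query g as a phase oracle\n",
    "    state = hadamard_all(state)\n",
    "    return float(abs(state[0]) ** 2)\n",
    "\n",
    "\n",
    "f_ex = np.array([1, -1, 1, -1])\n",
    "g_ex = np.array([1, -1, 1, -1])\n",
    "# Squared forrelation, (sum_{x,y} f(x) (-1)**(x . y) g(y))**2 / N**3, for comparison.\n",
    "double_sum = sum(\n",
    "    f_ex[x] * (-1) ** (bin(x & y).count('1') % 2) * g_ex[y] for x in range(4) for y in range(4)\n",
    ")\n",
    "print(fourier_checking_accept_prob(f_ex, g_ex), double_sum**2 / 4**3)  # 0.25 0.25"
   ]
  },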
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ddJVURrobh8g"
   },
   "source": [
    "It is clear that the probability of observation $|0\\rangle^{\\otimes n}$ equals the quantity:\n",
    "\\begin{align}\n",
    "  p(f,g) := \\frac{1}{N^3} \\left(\\sum_{x, y \\in \\{0,1\\}^n}f(x)(-1)^{x \\cdot y}g(y)\\right)^2\n",
    "\\end{align}\n",
    "\n",
    "It is shown in [[2]](https://arxiv.org/pdf/0910.4698.pdf) that\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "   {\\rm Pr}_{\\langle f, g \\rangle \\sim \\mathcal{U}} [p(f,g) \\geq 0.01] \\leq& \\; \\frac{100}{N} \\\\\n",
    "   {\\rm Pr}_{\\langle f, g \\rangle \\sim \\mathcal{F}} [p(f,g) \\geq 0.05] \\geq&\\;\\frac{1}{50} \\\\\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "This implies that the probability of the forrelation square between two functions drawing from uniform  distribution $\\mathcal{U}$ having larger than 0.01  will decaying quickly when we have more number of qubits. Hence, the Promise Fourier Checking problem can be solved through simply accepting when $p(f,g) \\geq 0.05$ and rejecting when $p(f,g) \\leq 0.01$ with constant error probability, using $O(1)$ queries to $f$ and $g$."
   ]
  },
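  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tail-bounds-sketch-md"
   },
   "source": [
    "The two bounds can be explored empirically with a small Monte Carlo sketch. It assumes $n = 10$, because for small $n$ the first bound $100/N$ exceeds 1 and says nothing. It samples pairs from $\\mathcal{U}$ (independent uniform $\\pm 1$ values) and from $\\mathcal{F}$ (signs of a Gaussian vector and of its Fourier transform) and reports the fraction of pairs whose $p(f,g)$ crosses each threshold. The helper `estimate_tail_fractions`, the seed, and the number of trials are arbitrary choices for illustration; the printed fractions are estimates that can be compared against the bounds $100/N$ and $1/50$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tail-bounds-sketch"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def estimate_tail_fractions(n: int, trials: int, seed: int) -> tuple[float, float]:\n",
    "    # Hypothetical helper: returns (fraction of U pairs with p >= 0.01,\n",
    "    #                               fraction of F pairs with p >= 0.05).\n",
    "    N = 2**n\n",
    "    w = np.array([[1.0]])\n",
    "    for _ in range(n):  # Sylvester construction: w[x, y] = (-1)**(x . y)\n",
    "        w = np.kron(np.array([[1.0, 1.0], [1.0, -1.0]]), w)\n",
    "    rng = np.random.default_rng(seed)\n",
    "    sgn = lambda a: np.where(a >= 0, 1.0, -1.0)\n",
    "    p = lambda f, g: float(f @ w @ g) ** 2 / N**3  # p(f, g) from the double sum\n",
    "    hits_u = hits_f = 0\n",
    "    for _ in range(trials):\n",
    "        f_u = sgn(rng.standard_normal(N))\n",
    "        g_u = sgn(rng.standard_normal(N))\n",
    "        hits_u += p(f_u, g_u) >= 0.01\n",
    "        v = rng.standard_normal(N)\n",
    "        v_hat = w @ v / np.sqrt(N)\n",
    "        hits_f += p(sgn(v), sgn(v_hat)) >= 0.05\n",
    "    return hits_u / trials, hits_f / trials\n",
    "\n",
    "\n",
    "frac_u, frac_f = estimate_tail_fractions(n=10, trials=200, seed=2021)\n",
    "print(f'U: Pr[p >= 0.01] ~ {frac_u:.3f}   (bound: <= {100 / 2**10:.3f})')\n",
    "print(f'F: Pr[p >= 0.05] ~ {frac_f:.3f}   (bound: >= {1 / 50:.3f})')"
   ]
  },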
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FDVxr7UKze2X"
   },
   "source": [
    "## Implementation in Cirq\n",
    "\n",
    "The algorithm above is straightforward to implement: it involves only Hadamard gates and the two function oracles. An oracle given by its truth table can be implemented in Cirq as a diagonal gate. To see why, consider an oracle $f$ defined over $\\mathbb{Z}_2^2$, queried on the uniform superposition:\n",
    "\n",
    "$$\n",
    "\\begin{align}\n",
    "  \\frac{1}{2}\\sum_{x\\in \\{0, 1\\}^2} f(x)|x\\rangle\n",
    "  =&\\frac{1}{2}\\Big(f(0,0)|00\\rangle + f(0,1)|01\\rangle + f(1,0)|10\\rangle + f(1,1)|11\\rangle\\Big)\\\\\n",
    "  =&\\;\\;\\left[ \\begin{array}{cccc}\n",
    "    f(0,0) & & & \\\\\n",
    "    & f(0,1) & & \\\\\n",
    "    & & f(1,0) & \\\\\n",
    "    & & & f(1,1) \\\\\n",
    "  \\end{array} \\right]\n",
    "  \\left[ \\begin{array}{c}\n",
    "  1/2\\\\\n",
    "  1/2\\\\\n",
    "  1/2\\\\\n",
    "  1/2\n",
    "  \\end{array} \\right]\n",
    "\\end{align}\n",
    "$$\n",
    "\n",
    "It is crucial to note that the output of $f$ is either 1 or -1, so the diagonal matrix is unitary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_k4-gc7SCBvS"
   },
   "outputs": [],
   "source": [
    "def oracle(fs: np.ndarray, qubits: Sequence[cirq.Qid]) -> cirq.Operation:\n",
    "    \"\"\"Construct a sample oracle using a function as above.\n",
    "\n",
    "    This will create an operation with a unitary matrix that is diagonal\n",
    "    and whose entries correspond to the values of the input function 'fs'.\n",
    "    \"\"\"\n",
    "    return cirq.MatrixGate(np.diag(fs).astype(complex))(*qubits)\n",
    "\n",
    "\n",
    "def fourier_checking_algorithm(qubits, fs, gs):\n",
    "    \"\"\"Returns the circuit for Fourier Checking algorithm given an input.\"\"\"\n",
    "    yield cirq.parallel_gate_op(cirq.H, *qubits)\n",
    "    yield oracle(fs, qubits)\n",
    "    yield cirq.parallel_gate_op(cirq.H, *qubits)\n",
    "    yield oracle(gs, qubits)\n",
    "    yield cirq.parallel_gate_op(cirq.H, *qubits)\n",
    "    yield cirq.measure(*qubits)\n",
    "\n",
    "\n",
    "qubits = cirq.LineQubit.range(n)\n",
    "fs, gs, forrelation, correlation = draw_two_distribution_from_f_set(N)\n",
    "circuit = cirq.Circuit(fourier_checking_algorithm(qubits, fs, gs))\n",
    "print(circuit)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cD7B8h5nA75O"
   },
   "source": [
    "We derived that the amplitude of $|0\\rangle^{\\otimes n}$ in the final state equals the forrelation between $f$ and $g$, so its square is $p(f,g)$; we can check this with Cirq. Keep in mind that the final state vector can *never* be obtained directly on real hardware. In simulation, however, it is available through `final_state_vector`, or moment by moment through the `dirac_notation` of each simulation step:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Y5nCR6jiU7Up"
   },
   "outputs": [],
   "source": [
    "assert np.isclose(\n",
    "    circuit.final_state_vector(ignore_terminal_measurements=True, dtype=np.complex64)[0],\n",
    "    forrelation,\n",
    ")\n",
    "\n",
    "s = cirq.Simulator()\n",
    "for step in s.simulate_moment_steps(circuit):\n",
    "    print(step.dirac_notation())\n",
    "    print(\"|0> state probability to observe: \", np.abs(step.state_vector(copy=True)[0]) ** 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9rEqwrI9KM4w"
   },
   "outputs": [],
   "source": [
    "final_state = circuit.final_state_vector(ignore_terminal_measurements=True, dtype=np.complex64)\n",
    "plt.fill_between(np.arange(len(final_state)), np.abs(final_state) ** 2)\n",
    "plt.xlabel(\"State of qubits\")\n",
    "plt.ylabel(\"Probability\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "m5Qzdo8rTPxJ"
   },
   "source": [
    "On real hardware, we can only measure the qubits, and each measurement yields a single basis state. To estimate the probability of observing $|0\\rangle^{\\otimes n}$, we run the circuit 100 times and use the frequency of the all-zeros outcome as an approximation of that probability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RS-njXvxM3pA"
   },
   "outputs": [],
   "source": [
    "repetitions = 100\n",
    "obs = s.run(circuit, repetitions=repetitions)\n",
    "qubits_name = ','.join(str(q) for q in qubits)\n",
    "times_zero_was_measured = len(obs.data[obs.data[qubits_name] == 0])\n",
    "print(\n",
    "    f\"Times the zero state was measured in {repetitions} repetitions: \"\n",
    "    f\"{times_zero_was_measured} ({100 * times_zero_was_measured / repetitions:.1f}%)\"\n",
    ")\n",
    "if times_zero_was_measured / repetitions > 0.05:\n",
    "    print(\"fs and gs are forrelated!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IHLLADQvT-lK"
   },
   "source": [
    "Finally, we randomly draw function pairs from either $\\mathcal{U}$ or $\\mathcal{F}$ and compute the confusion matrix of the Fourier Checking algorithm. The confusion matrix illustrates that the quantum algorithm solves Fourier Checking with $O(1)$ queries: although we measure 100 or 1000 times to estimate the probability, this number of repetitions does not depend on the number of qubits $n$ or the number of states $N$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PTq4ys9PanaM"
   },
   "outputs": [],
   "source": [
    "res = pd.DataFrame()\n",
    "repetitions = 100\n",
    "num_rounds = 1000\n",
    "for _ in range(num_rounds):\n",
    "    if np.random.rand() > 0.5:\n",
    "        fs, gs, _, _ = draw_two_distribution_from_f_set(N)\n",
    "        source = \"F set\"\n",
    "    else:\n",
    "        fs, gs, _, _ = draw_two_distribution_from_u_set(N)\n",
    "        source = \"U set\"\n",
    "\n",
    "    circuit = cirq.Circuit(fourier_checking_algorithm(qubits, fs, gs))\n",
    "    obs = s.run(circuit, repetitions=repetitions)\n",
    "    times_zero_was_measured = len(obs.data[obs.data[qubits_name] == 0])\n",
    "    decision = \"accept\" if times_zero_was_measured / repetitions > 0.05 else \"reject\"\n",
    "    res = pd.concat(\n",
    "        [res, pd.DataFrame({\"Source\": [source], \"Decision\": [decision], \"Count\": [1]})],\n",
    "        ignore_index=True,\n",
    "    )\n",
    "confusion = res.pivot_table(index=\"Source\", columns=\"Decision\", values=\"Count\", aggfunc=\"sum\")\n",
    "# Translate the counts into percentage\n",
    "confusion.div(confusion.sum(axis=1), axis=0).apply(lambda x: round(x, 4) * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "03GsjDCXst4E"
   },
   "source": [
    "# Query Complexity of Quantum Algorithm Versus the Classical one\n",
    "\n",
    "Having presented the quantum algorithm, we now turn to classical algorithms for comparison. It is not hard to give a classical algorithm that solves Fourier Checking using $O(\\sqrt{N}) = O(2^{n/2})$ queries, similar to what we did in the BPP section.\n",
    "\n",
    "For some $K=\\Theta(\\sqrt{N})$, first choose sets $X=\\{x_1,\\ldots, x_K\\}$ and $Y = \\{y_1, \\ldots, y_K\\}$ of $n$-bit strings uniformly at random. Then query $f(x_i)$ and $g(y_i)$ for all $i \\in [K]$. Finally, compute\n",
    "$$\n",
    "  Z := \\sum_{i,j=1}^K f(x_i) (-1)^{x_i\\cdot y_j} g(y_j)\n",
    "$$\n",
    "and accept if $|Z|$ is greater than some cutoff $cK$, and reject otherwise. For suitable $K$ and $c$, one can show that this algorithm accepts a forrelated pair $\\langle f, g \\rangle$ with probability at least $2/3$ and accepts a uniformly random pair with probability at most $1/3$. Compared with the $O(1)$ queries of the quantum algorithm, this is exponentially many in $n$; the separation is genuinely exponential because any classical randomized algorithm for Fourier Checking needs $\\Omega\\left(\\sqrt[4]{N}\\right) = \\Omega\\left(2^{n/4}\\right)$ queries [[2]](https://arxiv.org/pdf/0910.4698.pdf).\n",
    "\n",
    "<!-- Maybe we can find a better classical algorithm for forrelation problem, but it is proved that the lower bound on the classical query complexity is at least $\\Omega\\left(\\sqrt[4]{N}\\right) = \\Omega\\left(2^{n/4}\\right)$. -->\n"
   ]
  },
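  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The classical estimator above can be sketched in NumPy as follows. To have forrelated inputs to test it on, the sketch assumes the forrelated distribution of [[2]](https://arxiv.org/pdf/0910.4698.pdf): draw a Gaussian vector $v \\in \\mathbb{R}^N$ and set $f = \\mathrm{sgn}(v)$, $g = \\mathrm{sgn}(\\hat{v})$, where $\\hat{v}$ is the normalized Hadamard transform of $v$; uniform pairs are independent random $\\pm 1$ vectors. The choice $K = 4\\sqrt{N}$, the seed, and the helper names are illustrative assumptions, and the cutoff $c$ is left open: the sketch only reports the average of $|Z|/K$, which should be noticeably larger for forrelated pairs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def inner_product_parity(xs: np.ndarray, ys: np.ndarray, num_bits: int) -> np.ndarray:\n",
    "    \"\"\"Matrix of x . y mod 2 for all pairs, i.e. the parity of popcount(x & y).\"\"\"\n",
    "    anded = np.bitwise_and.outer(xs, ys)\n",
    "    parity = np.zeros_like(anded)\n",
    "    for bit in range(num_bits):\n",
    "        parity ^= (anded >> bit) & 1\n",
    "    return parity\n",
    "\n",
    "\n",
    "def sample_pair(num_bits: int, rng: np.random.Generator, forrelated: bool):\n",
    "    \"\"\"Hypothetical sampler for a uniform pair or a forrelated pair of +/-1 truth tables.\"\"\"\n",
    "    size = 2**num_bits\n",
    "    if not forrelated:\n",
    "        return rng.choice([1, -1], size=size), rng.choice([1, -1], size=size)\n",
    "    idx = np.arange(size)\n",
    "    had = (-1.0) ** inner_product_parity(idx, idx, num_bits) / np.sqrt(size)\n",
    "    v = rng.standard_normal(size)\n",
    "    return np.where(v >= 0, 1, -1), np.where(had @ v >= 0, 1, -1)\n",
    "\n",
    "\n",
    "def classical_statistic(f_vals, g_vals, num_bits, k, rng) -> float:\n",
    "    \"\"\"Query f and g on K random strings each and return |Z| / K.\"\"\"\n",
    "    xs = rng.choice(2**num_bits, size=k, replace=False)\n",
    "    ys = rng.choice(2**num_bits, size=k, replace=False)\n",
    "    signs = 1 - 2 * inner_product_parity(xs, ys, num_bits)  # (-1)^(x_i . y_j)\n",
    "    z = f_vals[xs] @ signs @ g_vals[ys]\n",
    "    return abs(z) / k\n",
    "\n",
    "\n",
    "cl_rng = np.random.default_rng(seed=0)\n",
    "cl_bits = 8\n",
    "cl_k = 4 * int(np.sqrt(2**cl_bits))  # K = Theta(sqrt(N)); the factor 4 is arbitrary\n",
    "cl_means = {}\n",
    "for cl_forrelated in (False, True):\n",
    "    cl_stats = [\n",
    "        classical_statistic(*sample_pair(cl_bits, cl_rng, cl_forrelated), cl_bits, cl_k, cl_rng)\n",
    "        for _ in range(200)\n",
    "    ]\n",
    "    cl_means[\"F set\" if cl_forrelated else \"U set\"] = float(np.mean(cl_stats))\n",
    "print(\"Mean |Z|/K per source:\", cl_means)"
   ]
  },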
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HpOFrK4QW5Y4"
   },
   "source": [
    "# Further reading\n",
    "\n",
    "The Forrelation problem was originally introduced in [[2]](https://arxiv.org/pdf/0910.4698.pdf). Later, the **$k$-fold Forrelation** problem, which considers the forrelation among $k$ oracle functions, was introduced in [[3]](https://arxiv.org/pdf/1411.5729.pdf). That paper also improved the classical lower bound for Forrelation from $\\Omega\\left(\\sqrt[4]{N}\\right)$ to $\\tilde{\\Omega}\\left(\\sqrt{N}\\right)$, which cannot be further improved up to logarithmic factors. It further showed that any quantum algorithm making $t$ queries can be simulated by a classical randomized algorithm making $O\\left(N^{1-1/2t}\\right)$ queries. Thus, resolving an open question of Buhrman et al., there is no partial Boolean function whose quantum query complexity is constant and whose randomized query complexity is linear.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "92aaa99fa788"
   },
   "source": [
    "# References\n",
    "\n",
    "[[1]](https://www.sciencedirect.com/science/article/pii/S030439750100144X) Harry Buhrman and Ronald de Wolf, \"Complexity measures and decision tree complexity: a survey\", Theoretical Computer Science 288, no. 1 (2002): 21–43.\n",
    "\n",
    "[[2]](https://arxiv.org/pdf/0910.4698.pdf) Scott Aaronson, \"BQP and the Polynomial Hierarchy\", Proceedings of STOC ’10, New York, NY, USA (2010): 141–150.\n",
    "\n",
    "[[3]](https://arxiv.org/pdf/1411.5729.pdf) Scott Aaronson and Andris Ambainis, \"Forrelation: A problem that optimally separates quantum from classical computing\", SIAM J. Comput. 47, no. 3 (2018): 982–1038.\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "fourier_checking.ipynb",
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
